package middleware

import (
	"bytes"
	"io/ioutil"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/sirupsen/logrus"
)

// responseBodyWriter wraps a gin.ResponseWriter and mirrors every byte that
// is written to the client into body, so the response can be logged later.
type responseBodyWriter struct {
	gin.ResponseWriter
	// body accumulates a copy of everything written through Write/WriteString.
	body *bytes.Buffer
}

// Write records p in the capture buffer and then forwards it unchanged to the
// wrapped writer, returning that writer's result.
func (rw responseBodyWriter) Write(p []byte) (int, error) {
	// bytes.Buffer.Write always returns a nil error, so its result is dropped.
	_, _ = rw.body.Write(p)
	return rw.ResponseWriter.Write(p)
}

// WriteString records str in the capture buffer and then forwards it unchanged
// to the wrapped writer, returning that writer's result.
func (rw responseBodyWriter) WriteString(str string) (int, error) {
	// bytes.Buffer.WriteString always returns a nil error, so its result is dropped.
	_, _ = rw.body.WriteString(str)
	return rw.ResponseWriter.WriteString(str)
}

// LoggerField maps a log field name to a function that derives that field's
// value from the current request's gin.Context. Logger evaluates these
// functions after the rest of the handler chain has run.
type LoggerField map[string]func(*gin.Context) string

// Logger returns a gin middleware that emits one logrus entry per request.
//
// The entry carries timing, client and request metadata, the raw request
// body, the captured response body, and any extra fields supplied via
// fields. Extra fields are applied after the built-in ones, so a LoggerField
// may override a built-in key. Requests whose gin.Context holds errors are
// logged at Error level; all others at Info level.
//
// Both the request and the response bodies are buffered in memory in full,
// so this middleware is ill-suited to routes with large or streaming bodies.
func Logger(fields ...LoggerField) func(*gin.Context) {
	return func(c *gin.Context) {
		startTime := time.Now()

		// Read the whole request body for logging, then put an equivalent
		// reader back so downstream handlers can still consume it.
		body, readErr := c.GetRawData()
		c.Request.Body = ioutil.NopCloser(bytes.NewBuffer(body))

		// Swap in a writer that mirrors the response into a buffer.
		w := &responseBodyWriter{body: &bytes.Buffer{}, ResponseWriter: c.Writer}
		c.Writer = w

		c.Next()
		cost := time.Since(startTime)

		errText := c.Errors.String()
		entryFields := logrus.Fields{
			"cost_string":     cost.String(),
			"cost":            cost.Milliseconds(),
			"error":           errText,
			"client_ip":       c.ClientIP(),
			"req_uri":         c.Request.RequestURI,
			"req_method":      c.Request.Method,
			"host":            c.Request.Host,
			"request_referer": c.Request.Referer(),
			"req_body":        string(body),
			"response":        w.body.String(),
		}
		if readErr != nil {
			// Logging is best-effort: record the read failure rather than
			// aborting the request.
			entryFields["req_body_error"] = readErr.Error()
		}
		// Merge caller-supplied fields into the map itself. logrus'
		// Entry.WithField returns a new entry instead of mutating the
		// receiver, so chaining it without reassignment would drop them.
		for _, field := range fields {
			for key, valueOf := range field {
				entryFields[key] = valueOf(c)
			}
		}
		entry := logrus.WithFields(entryFields)

		if errText == "" {
			entry.Info()
			return
		}
		entry.Error()
	}
}

// GetUserIDField returns a LoggerField that logs, under the name "user", the
// string stored in the gin.Context under userIDKey.
func GetUserIDField(userIDKey string) LoggerField {
	const logName = "user"
	return WarpContextField(userIDKey, logName)
}

// GetTokenField returns a LoggerField that logs, under the name
// "access_token", the string stored in the gin.Context under tokenKey.
//
// NOTE(review): this writes the raw token value into the logs; confirm that
// is acceptable for wherever these logs are shipped.
func GetTokenField(tokenKey string) LoggerField {
	const logName = "access_token"
	return WarpContextField(tokenKey, logName)
}

// WarpContextField builds a LoggerField with a single entry, named key, whose
// value is read from the gin.Context under contextKey via GetString (which
// yields "" when the key is absent or does not hold a string).
//
// The name presumably means "Wrap"; it is kept as-is for compatibility.
func WarpContextField(contextKey, key string) LoggerField {
	field := make(LoggerField, 1)
	field[key] = func(c *gin.Context) string { return c.GetString(contextKey) }
	return field
}
